{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.2942495 ,  0.95572865, -0.12561874], dtype=float32)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "# Environment definition\n",
    "class MyWrapper(gym.Wrapper):\n",
    "    \"\"\"Pendulum-v1 wrapper exposing a 4-tuple step API.\n",
    "\n",
    "    reset() returns only the observation, and step() returns\n",
    "    (state, reward, done, info). done merges the terminated and truncated\n",
    "    flags and is also forced once max_steps steps have been taken.\n",
    "\n",
    "    Args:\n",
    "        max_steps: Step count at which an episode is reported as done.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, max_steps=200):\n",
    "        env = gym.make(\"Pendulum-v1\", render_mode=\"rgb_array\")\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        self.max_steps = max_steps\n",
    "        self.step_n = 0  # steps taken since the last reset()\n",
    "\n",
    "    def reset(self, **kwargs):\n",
    "        \"\"\"Start a new episode and return its first observation.\n",
    "\n",
    "        Keyword arguments (for example seed or options) are forwarded to the\n",
    "        wrapped environment's reset(); the info dict it returns is discarded.\n",
    "        \"\"\"\n",
    "        state, _ = self.env.reset(**kwargs)\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        \"\"\"Advance one step and return (state, reward, done, info).\"\"\"\n",
    "        state, reward, terminated, truncated, info = self.env.step(action)\n",
    "        self.step_n += 1\n",
    "        # Collapse both end-of-episode flags and enforce the local step cap.\n",
    "        done = terminated or truncated or self.step_n >= self.max_steps\n",
    "        return state, reward, done, info\n",
    "\n",
    "\n",
    "# Build the shared environment instance used by the later cells.\n",
    "env = MyWrapper()\n",
    "\n",
    "# Reset once; as the cell's last expression, the initial observation is displayed.\n",
    "env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Box(-2.0, 2.0, (1,), float32)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the environment's action space.\n",
    "env.action_space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAakAAAGiCAYAAABd6zmYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAkSElEQVR4nO3df3DU9YH/8df+yG5+7oYEsiGQSAQU0/BDE8Btb669I5L2Mj2t3Ix1GMt5Xh1p5EQ63shdpdN+ewdj73ttvVPauU7FP65yw/XQykHbXNBQS4QQQCFA/FEkUdiEH2Y3CWQ3yb6/f3Ds11W0SUiy74TnY2ZnzOfz3t33523cp7v72Y3DGGMEAICFnKmeAAAAn4RIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCslbJIPf3005o1a5bS09O1dOlS7d+/P1VTAQBYKiWR+o//+A+tW7dO3/72t3Xw4EEtXLhQ1dXV6uzsTMV0AACWcqTiC2aXLl2qxYsX61//9V8lSfF4XMXFxVqzZo0ef/zx8Z4OAMBS7vG+w1gspubmZq1fvz6xzel0qqqqSo2NjVe9TjQaVTQaTfwcj8d14cIF5efny+FwjPmcAQCjyxij7u5uFRUVyen85Bf1xj1S586d0+DgoAKBQNL2QCCgEydOXPU6Gzdu1He+853xmB4AYBy1t7dr5syZn7h/3CM1EuvXr9e6desSP4fDYZWUlKi9vV0+ny+FMwMAjEQkElFxcbFycnI+ddy4R2rq1KlyuVzq6OhI2t7R0aHCwsKrXsfr9crr9X5su8/nI1IAMIH9obdsxv3sPo/Ho4qKCtXX1ye2xeNx1dfXKxgMjvd0AAAWS8nLfevWrdOqVatUWVmpJUuW6Ic//KF6e3t1//33p2I6AABLpSRS99xzj86ePasNGzYoFApp0aJF+tWvfvWxkykAANe3lHxO6lpFIhH5/X6Fw2HekwKACWioj+N8dx8AwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaw07Unv27NGXv/xlFRUVyeFw6IUXXkjab4zRhg0bNH36dGVkZKiqqkpvvfVW0pgLFy5o5cqV8vl8ys3N1QMPPKCenp5rOhAAwOQz7Ej19vZq4cKFevrpp6+6/8knn9RTTz2lH//4x9q3b5+ysrJUXV2tvr6+xJiVK1eqpaVFdXV12rFjh/bs2aMHH3xw5EcBAJiczDWQZLZv3574OR6Pm8LCQvP9738/sa2rq8t4vV7z/PPPG2OMOXbsmJFkmpqaEmN27dplHA6Hef/994d0v+Fw2Egy4XD4WqYPAEiRoT6Oj+p7UidPnlQoFFJVVVVim9/v19KlS9XY2ChJamxsVG5uriorKxNjqqqq5HQ6tW/fvqvebjQaVSQSSboAACa/UY1UKBSSJAUCgaTtgUAgsS8UCqmgoCBpv9vtVl5eXmLMR23cuFF+vz9xKS4uHs1pAwAsNSHO7lu/fr3C4XDi0t7enuopAQDGwahGqrCwUJLU0dGRtL2joyOxr7CwUJ2dnUn7BwYGdOHChcSYj/J6vfL5fEkXAMDkN6qRKi0tVWFhoerr6xPbIpGI9u3b
p2AwKEkKBoPq6upSc3NzYszu3bsVj8e1dOnS0ZwOAGCCcw/3Cj09PXr77bcTP588eVKHDx9WXl6eSkpKtHbtWn3ve9/T3LlzVVpaqieeeEJFRUW66667JEm33HKLvvjFL+rrX/+6fvzjH6u/v18PP/ywvvrVr6qoqGjUDgwAMAkM97TBl19+2Uj62GXVqlXGmMunoT/xxBMmEAgYr9drli1bZlpbW5Nu4/z58+bee+812dnZxufzmfvvv990d3eP+qmLAAA7DfVx3GGMMSls5IhEIhH5/X6Fw2HenwKACWioj+MT4uw+AMD1iUgBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGCtYUVq48aNWrx4sXJyclRQUKC77rpLra2tSWP6+vpUW1ur/Px8ZWdna8WKFero6Ega09bWppqaGmVmZqqgoECPPfaYBgYGrv1oAACTyrAi1dDQoNraWr322muqq6tTf3+/li9frt7e3sSYRx99VC+99JK2bdumhoYGnT59WnfffXdi/+DgoGpqahSLxbR3714999xz2rJlizZs2DB6RwUAmBzMNejs7DSSTENDgzHGmK6uLpOWlma2bduWGHP8+HEjyTQ2NhpjjNm5c6dxOp0mFAolxmzevNn4fD4TjUaHdL/hcNhIMuFw+FqmDwBIkaE+jl/Te1LhcFiSlJeXJ0lqbm5Wf3+/qqqqEmPmzZunkpISNTY2SpIaGxs1f/58BQKBxJjq6mpFIhG1tLRc9X6i0agikUjSBQAw+Y04UvF4XGvXrtXnPvc5lZeXS5JCoZA8Ho9yc3OTxgYCAYVCocSYDwfqyv4r+65m48aN8vv9iUtxcfFIpw0AmEBGHKna2lodPXpUW7duHc35XNX69esVDocTl/b29jG/TwBA6rlHcqWHH35YO3bs0J49ezRz5szE9sLCQsViMXV1dSU9m+ro6FBhYWFizP79+5Nu78rZf1fGfJTX65XX6x3JVDGOBi9dUqyjQ/3hsOKxmBwulxxOpxxut9KLiuTOzZXDyaceAAzdsCJljNGaNWu0fft2vfLKKyotLU3aX1FRobS0NNXX12vFihWSpNbWVrW1tSkYDEqSgsGg/uEf/kGdnZ0qKCiQJNXV1cnn86msrGw0jgnjzAwOquf4cV3Ys0e9x48r2tmpeF+fHG735Uh5PMq66SZlzp6t7FtuUda8eXJnZaV62gAmAIcxxgx18De+8Q39/Oc/14svvqibb745sd3v9ysjI0OStHr1au3cuVNbtmyRz+fTmjVrJEl79+6VdPkU9EWLFqmoqEhPPvmkQqGQ7rvvPv31X/+1/vEf/3FI84hEIvL7/QqHw/L5fEM+WIy+eCymzl271PGLX2ggEpHi8U8e7HLJlZUlT16essvLlbtkiTJvvFHO9HQ50tLkcDjGb+IAUmqoj+PDitQnPYg8++yz+su//EtJlz/M+81vflPPP/+8otGoqqur9cwzzyS9lHfq1CmtXr1ar7zyirKysrRq1Spt2rRJbvfQntgRKTsMXryojl/+Uh2/+IXi0eiIbsPt9yv39tuVXVamjFmz5J0+XU6vl2ABk9yYRMoWRCr1TDyuD377W7X9279pcDQ+EuB0KuOGG5Q+Y4ayy8qUs3Ch0ouK
5HC5rv22AVhnqI/jIzpxAhiIRNT2k59osKdndG4wHtelkyd16eRJde3fL1dGhrwzZmhKMCjfokVKy8uTMyNDziE+2wYwOfBfPEbk3P/8z+gF6iNMLKaBWEwD4bB6jx2TMz1dmbNnK2fhQmWWlipj1ix5Cgp4SRC4DhApjMgHDQ3jdl/xvj71tLSop6VFruxspRcXK6OkRL7bblPOggVyZWYmxhIuYHIhUphQBnt61Hv8uHpPnNCFhgY509Plr6iQf8kSZZaWyu3zXT5bkM9jAZMCkcLEZIzifX2K9/XpfH29ztfXK23aNOWUl1/+TNaNNyqjpEQuPo8FTGhECpNG/9mzuvDyy/rg1VflmTpVnoICZd18s6YEg8ooLZX+96VAXhIEJg4ihRFJmzpVl06dSvU0rsr09yt65oyiZ86o+8gRdb74ojwFBcq9/XblLlmitPx8uXNy5PB4CBZgOT4nhRG5dOqUjv3N30gT7NfH4XYrc84cZd10k7JuvlmZs2fLW1AgB6e2A+OKz0lhTHmnT9fUO+7Qud/8JtVTGRYzMKDeEyfUe+KEXDk58uTnK72kRLmLFytn/ny5/X7J6eQZFmAJIoURcaSlKfAXf6H+ri6FDxz49O/ss9Rgd7cudXfr0rvvqut3v5NcLuXMn68pn/2sMmfPVlpentx+P8ECUohIYUQcDofSCws142tfk4nHFTlwINVTuiZmcFAaHFSkuVmR5ma5p0xJvCSYNWeOMufOlSszk2AB44z3pHDNop2dOrtzp87v3q2Brq5UT2dUOVwuuf1+peXnK2f+fPkqK5V1441yejySy0W0gBHiC2YxbowxkjGKnT2r9372M/W1tyva2SkTi6V6aqPL4ZAcDrkyM5X3J3+inM98RuklJfJMnSpXenqqZwdMKEQKKWEGBnTp1Cn1vv22Lr7zjnpPnFDfe+/JDAykempjIuPGG5VZWqqsuXOVPX++0mfM4NsugCEgUkgp87/fCDHQ3a1YR4e6mpoUaW5W7OxZxfv7pcHBVE9xVDnT0+X2+eSdMUP+xYuVW1kpd26unGlpnC0IXAWRgjWu/IqZgQFdevddhZua1NvaqmhHh6KnT6d4dmPD6fUq86ab5Lv1VmXeeKPSi4vlyc/nWRbwv4gUrDbQ26u+9nb1trbq4smT6jl2TLFQKNXTGhOuzExlzpmj9Jkz5auoUPYtt8idnZ3qaQEpRaQwIZh4XPG+Pg329uriyZMKNzUpcvCgBnp6Lv9J+gn4+atP48rOliszUzmLFl3+PNasWXJmZsrp8fAsC9cVIoUJ5cO/hvFoVD0tLYocPqxL776rS6dOTbpT269ImzpVvkWLlP2ZzyijuFjpM2bwze24LhApTHjGGPVfuKBLp06pr71d3UeOqOf4cQ12d6d6aqPOkZYmbyAg74wZypw9W/7KSmXOnp08hpMvMIkQKUwqZnBQ8WhUAz09l//g4auv6uLbbyt+8aIGL15M9fRGlcPtljMjQ06PR87MTGXPm6e8P/5jZc6eLVdWFrHCpECkMOn1f/CBultaFDl4UNEzZ9T33nsaCIdTPa0x4fB45L/tNhXcdZeyb7mFUGHCI1K4bsQHBtR//ryiZ86o9803FXnjDfUePy7T35/qqY269BtuUMmDDyq7vJxQYUIjUrjuGGOkwUHF+/s1EA6rq6lJ4f37FT19WgORyOWzBSeB9BtuUMnq1TyjwoRGpABdDlffe++p+8gRXXz77cRJGPG+vlRP7RO939urQxcuqLu/X9PS0xWcNk1ZaWlJY3IWLtSNf/u3cufkpGiWwLXhjx4CunxGXEZxsTKKixWPRhU7e1bRUEg9x48r3NSkS21tl/+6sAX/r2aM0cmeHn370CG929OjvsFB+dLSVD5liv5p8WKlfehzVN2vv66L77yjnIULeTaFSY1nUrjuGGOkeFzx/n5FQyGFm5rU9dprGgiHFTt/PmXfK/hOd7ce/N3vFL7Ke2lLpk7V/7n1VuV/6NvW
vUVF+szmzUQKExLPpIBP4HA4JJdLLpdLmbNmKXPWLAXuvFMX331X3a+/rr72dl38/e/V9/774xqsH7a0XDVQkrT/3DnVnT6tr954Y2LbZDwxBPgoIgVIcno8yr7pJmXNnavBixc18MEH6jt9OvE1Tf0ffHD5r/dOvBcegAmNSAEf4nA45M7Kkjsr6/Kf3aiokInH1dvaqg/27tXFt95S7Px5DXR1Tdq/kQXYhEgBn+DKy4IOl0s55eXKKS/XQHe3Lv7+97r4zju6+Pbb6jlxQv3nzo3K/dUUF+vAuXPqv8qztVnZ2VqQlzcq9wNMJEQKGAZ3To58Cxcqp7xcgxcvqv+DDy7/jawDBxRualI8Fhvxe0XVRUWSpO+9/rpig4OKS3I5HMr1ePR/Fy/WDR/58x6BFSuu9XAA6xEpYAQcLpfcOTlyZWcrvbhYU/7ojxS/dEndR4/qg9/9Tn2nT6v/7Fn1d3UN+X0sh8Oh6qIizczM1I733tP5vj7Nys7WPaWlyvd6k8Z6p0/XlGCQM/sw6REp4BokIuFwyJWVpdylS+VfskSxc+fUd+rU5T/oePSoet98U4O9vUO6vfIpU1Q+ZconjnFPmaKi++6Tm49f4DpApIBR5nA45J02Td5p0+RbtEgDy5droLtbPUeP6sKrr+rS739/+WXBWGzYt+3Kztb0e+5R7pIlcrhcYzB7wC5EChhDDrdbaX6/3D6f0mfM0NTqavVfuKDIwYMKHzig2Pnzip4+rcGenk+/HZdL6cXFCnzlK8r7whd4mQ/XDSIFjIMPR8WTn6+pd9yhvC98QbGzZ3Xp1CldfOstdR87pr5Tp+SZNk3RUEjxaPTye14zZ8pfWSl/ZaUySkoIFK4rRApIEWdamtKLiuSdPl3+ykoFLl1SPBaTw+2WGRiQicflcLnkTEuTMzNTTjf/ueL6w289kGIOh0OOtDQ5P/JN5wAk5x8eAgBAahApAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYaVqQ2b96sBQsWyOfzyefzKRgMateuXYn9fX19qq2tVX5+vrKzs7VixQp1dHQk3UZbW5tqamqUmZmpgoICPfbYYxoYGBidowEATCrDitTMmTO1adMmNTc368CBA/rTP/1T3XnnnWppaZEkPfroo3rppZe0bds2NTQ06PTp07r77rsT1x8cHFRNTY1isZj27t2r5557Tlu2bNGGDRtG96gAAJODuUZTpkwxP/3pT01XV5dJS0sz27ZtS+w7fvy4kWQaGxuNMcbs3LnTOJ1OEwqFEmM2b95sfD6fiUajQ77PcDhsJJlwOHyt0wcApMBQH8dH/J7U4OCgtm7dqt7eXgWDQTU3N6u/v19VVVWJMfPmzVNJSYkaGxslSY2NjZo/f74CgUBiTHV1tSKRSOLZ2NVEo1FFIpGkCwBg8ht2pI4cOaLs7Gx5vV499NBD2r59u8rKyhQKheTxeJSbm5s0PhAIKBQKSZJCoVBSoK7sv7Lvk2zcuFF+vz9xKS4uHu60AQAT0LAjdfPNN+vw4cPat2+fVq9erVWrVunYsWNjMbeE9evXKxwOJy7t7e1jen8AADu4h3sFj8ejOXPmSJIqKirU1NSkH/3oR7rnnnsUi8XU1dWV9Gyqo6NDhYWFkqTCwkLt378/6faunP13ZczVeL1eeb3e4U4VADDBXfPnpOLxuKLRqCoqKpSWlqb6+vrEvtbWVrW1tSkYDEqSgsGgjhw5os7OzsSYuro6+Xw+lZWVXetUAACTzLCeSa1fv15f+tKXVFJSou7ubv385z/XK6+8ol//+tfy+/164IEHtG7dOuXl5cnn82nNmjUKBoO6/fbbJUnLly9XWVmZ7rvvPj355JMKhUL61re+pdraWp4pAQA+ZliR
6uzs1Ne+9jWdOXNGfr9fCxYs0K9//WvdcccdkqQf/OAHcjqdWrFihaLRqKqrq/XMM88kru9yubRjxw6tXr1awWBQWVlZWrVqlb773e+O7lEBACYFhzHGpHoSwxWJROT3+xUOh+Xz+VI9HQDAMA31cZzv7gMAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgrWuK1KZNm+RwOLR27drEtr6+PtXW1io/P1/Z2dlasWKFOjo6kq7X1tammpoaZWZmqqCgQI899pgGBgauZSoAgEloxJFqamrST37yEy1YsCBp+6OPPqqXXnpJ27ZtU0NDg06fPq277747sX9wcFA1NTWKxWLau3evnnvuOW3ZskUbNmwY+VEAACYnMwLd3d1m7ty5pq6uznz+8583jzzyiDHGmK6uLpOWlma2bduWGHv8+HEjyTQ2NhpjjNm5c6dxOp0mFAolxmzevNn4fD4TjUaHdP/hcNhIMuFweCTTBwCk2FAfx0f0TKq2tlY1NTWqqqpK2t7c3Kz+/v6k7fPmzVNJSYkaGxslSY2NjZo/f74CgUBiTHV1tSKRiFpaWq56f9FoVJFIJOkCAJj83MO9wtatW3Xw4EE1NTV9bF8oFJLH41Fubm7S9kAgoFAolBjz4UBd2X9l39Vs3LhR3/nOd4Y7VQDABDesZ1Lt7e165JFH9O///u9KT08fqzl9zPr16xUOhxOX9vb2cbtvAEDqDCtSzc3N6uzs1G233Sa32y23262GhgY99dRTcrvdCgQCisVi6urqSrpeR0eHCgsLJUmFhYUfO9vvys9XxnyU1+uVz+dLugAAJr9hRWrZsmU6cuSIDh8+nLhUVlZq5cqViX9OS0tTfX194jqtra1qa2tTMBiUJAWDQR05ckSdnZ2JMXV1dfL5fCorKxulwwIATAbDek8qJydH5eXlSduysrKUn5+f2P7AAw9o3bp1ysvLk8/n05o1axQMBnX77bdLkpYvX66ysjLdd999evLJJxUKhfStb31LtbW18nq9o3RYAIDJYNgnTvwhP/jBD+R0OrVixQpFo1FVV1frmWeeSex3uVzasWOHVq9erWAwqKysLK1atUrf/e53R3sqAIAJzmGMMamexHBFIhH5/X6Fw2HenwKACWioj+N8dx8AwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFruVE9gJIwxkqRIJJLimQAARuLK4/eVx/NPMiEjdf78eUlScXFximcCALgW3d3d8vv9
n7h/QkYqLy9PktTW1vapB3e9i0QiKi4uVnt7u3w+X6qnYy3WaWhYp6FhnYbGGKPu7m4VFRV96rgJGSmn8/JbaX6/n1+CIfD5fKzTELBOQ8M6DQ3r9IcN5UkGJ04AAKxFpAAA1pqQkfJ6vfr2t78tr9eb6qlYjXUaGtZpaFinoWGdRpfD/KHz/wAASJEJ+UwKAHB9IFIAAGsRKQCAtYgUAMBaEzJSTz/9tGbNmqX09HQtXbpU+/fvT/WUxtWePXv05S9/WUVFRXI4HHrhhReS9htjtGHDBk2fPl0ZGRmqqqrSW2+9lTTmwoULWrlypXw+n3Jzc/XAAw+op6dnHI9ibG3cuFGLFy9WTk6OCgoKdNddd6m1tTVpTF9fn2pra5Wfn6/s7GytWLFCHR0dSWPa2tpUU1OjzMxMFRQU6LHHHtPAwMB4HsqY2rx5sxYsWJD44GkwGNSuXbsS+1mjq9u0aZMcDofWrl2b2MZajREzwWzdutV4PB7zs5/9zLS0tJivf/3rJjc313R0dKR6auNm586d5u///u/Nf/3XfxlJZvv27Un7N23aZPx+v3nhhRfM66+/bv78z//clJaWmkuXLiXGfPGLXzQLFy40r732mvntb39r5syZY+69995xPpKxU11dbZ599llz9OhRc/jwYfNnf/ZnpqSkxPT09CTGPPTQQ6a4uNjU19ebAwcOmNtvv9189rOfTewfGBgw5eXlpqqqyhw6dMjs3LnTTJ061axfvz4VhzQmfvnLX5r//u//Nm+++aZpbW01f/d3f2fS0tLM0aNHjTGs0dXs37/fzJo1yyxYsMA88sgjie2s1diYcJFasmSJqa2tTfw8ODhoioqKzMaNG1M4q9T5aKTi8bgpLCw03//+9xPburq6jNfrNc8//7wxxphjx44ZSaapqSkxZteuXcbhcJj3339/3OY+njo7O40k09DQYIy5vCZpaWlm27ZtiTHHjx83kkxjY6Mx5vL/DDidThMKhRJjNm/ebHw+n4lGo+N7AONoypQp5qc//SlrdBXd3d1m7ty5pq6uznz+859PRIq1GjsT6uW+WCym5uZmVVVVJbY5nU5VVVWpsbExhTOzx8mTJxUKhZLWyO/3a+nSpYk1amxsVG5uriorKxNjqqqq5HQ6tW/fvnGf83gIh8OS/v+XEzc3N6u/vz9pnebNm6eSkpKkdZo/f74CgUBiTHV1tSKRiFpaWsZx9uNjcHBQW7duVW9vr4LBIGt0FbW1taqpqUlaE4nfp7E0ob5g9ty5cxocHEz6lyxJgUBAJ06cSNGs7BIKhSTpqmt0ZV8oFFJBQUHSfrfbrby8vMSYySQej2vt2rX63Oc+p/LyckmX18Dj8Sg3Nzdp7EfX6WrreGXfZHHkyBEFg0H19fUpOztb27dvV1lZmQ4fPswafcjWrVt18OBBNTU1fWwfv09jZ0JFChiJ2tpaHT16VK+++mqqp2Klm2++WYcPH1Y4HNZ//ud/atWqVWpoaEj1tKzS3t6uRx55RHV1dUpPT0/1dK4rE+rlvqlTp8rlcn3sjJmOjg4VFhamaFZ2ubIOn7ZGhYWF6uzsTNo/MDCgCxcuTLp1fPjhh7Vjxw69/PLLmjlzZmJ7YWGhYrGYurq6ksZ/dJ2uto5X9k0WHo9Hc+bMUUVFhTZu3KiFCxfqRz/6EWv0Ic3Nzers7NRtt90mt9stt9uthoYGPfXUU3K73QoEAqzVGJlQkfJ4PKqoqFB9fX1iWzweV319vYLBYApnZo/S0lIVFhYmrVEkEtG+ffsSaxQMBtXV1aXm5ubEmN27dysej2vp0qXjPuexYIzRww8/rO3bt2v37t0qLS1N2l9RUaG0tLSkdWptbVVbW1vSOh05ciQp6HV1dfL5fCorKxufA0mBeDyuaDTKGn3IsmXLdOTIER0+fDhxqays1MqVKxP/zFqNkVSfuTFcW7duNV6v12zZssUcO3bMPPjggyY3NzfpjJnJrru72xw6dMgcOnTISDL/
/M//bA4dOmROnTpljLl8Cnpubq558cUXzRtvvGHuvPPOq56Cfuutt5p9+/aZV1991cydO3dSnYK+evVq4/f7zSuvvGLOnDmTuFy8eDEx5qGHHjIlJSVm9+7d5sCBAyYYDJpgMJjYf+WU4eXLl5vDhw+bX/3qV2batGmT6pThxx9/3DQ0NJiTJ0+aN954wzz++OPG4XCY3/zmN8YY1ujTfPjsPmNYq7Ey4SJljDH/8i//YkpKSozH4zFLliwxr732WqqnNK5efvllI+ljl1WrVhljLp+G/sQTT5hAIGC8Xq9ZtmyZaW1tTbqN8+fPm3vvvddkZ2cbn89n7r//ftPd3Z2CoxkbV1sfSebZZ59NjLl06ZL5xje+YaZMmWIyMzPNV77yFXPmzJmk23n33XfNl770JZORkWGmTp1qvvnNb5r+/v5xPpqx81d/9VfmhhtuMB6Px0ybNs0sW7YsEShjWKNP89FIsVZjgz/VAQCw1oR6TwoAcH0hUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFr/D/ZC4d4eYta1AAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "# Display the game\n",
    "def show():\n",
    "    \"\"\"Draw the frame returned by the global env's render() with matplotlib.\"\"\"\n",
    "    frame = env.render()\n",
    "    plt.imshow(frame)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Draw the environment's current frame.\n",
    "show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(-1.7949174642562866, -1201.5130408316786)"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from IPython import display\n",
    "import random\n",
    "\n",
    "\n",
    "class Model(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc_statu = torch.nn.Sequential(\n",
    "            torch.nn.Linear(3, 128),\n",
    "            torch.nn.ReLU(),\n",
    "        )\n",
    "\n",
    "        self.fc_mu = torch.nn.Sequential(\n",
    "            torch.nn.Linear(128, 1),\n",
    "            torch.nn.Tanh(),  # 这里经过Tanh输出是-1，1\n",
    "        )\n",
    "\n",
    "        self.fc_std = torch.nn.Sequential(\n",
    "            torch.nn.Linear(128, 1),\n",
    "            torch.nn.Softplus(),\n",
    "        )\n",
    "\n",
    "    def forward(self, state):\n",
    "        state = self.fc_statu(state)\n",
    "\n",
    "        mu = self.fc_mu(state) * 2.0  # *2代表把[-1~1] --> [-2~2]\n",
    "        std = self.fc_std(state)\n",
    "        # print(f'mu:{mu.item()},std:{std.item()}')\n",
    "\n",
    "        return mu, std\n",
    "\n",
    "\n",
    "class PPO:\n",
    "    def __init__(self):\n",
    "        self.model = Model().to(device=\"cuda\")  # 输入状态，返回一个正态分布\n",
    "        self.model_td = torch.nn.Sequential(\n",
    "            torch.nn.Linear(3, 128),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Linear(128, 1),\n",
    "        ).to(device='cuda')  # 根据状态得到一个分数\n",
    "\n",
    "        # 定义优化器\n",
    "        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)\n",
    "        self.optimizer_td = torch.optim.Adam(self.model_td.parameters(), lr=5e-3)\n",
    "        self.loss_fn = torch.nn.MSELoss()\n",
    "\n",
    "    def get_action(self, state):  # 根据状态得到一个动作\n",
    "        state = torch.FloatTensor(state).reshape(1, 3).to(device=\"cuda\")\n",
    "        mu, std = self.model(state)\n",
    "\n",
    "        action = torch.distributions.Normal(mu, std).sample().item()\n",
    "        return action\n",
    "\n",
    "    def get_data(self):\n",
    "        states = []\n",
    "        rewards = []\n",
    "        actions = []\n",
    "        next_states = []\n",
    "        overs = []\n",
    "\n",
    "        # 初始化游戏\n",
    "        state = env.reset()\n",
    "        # 玩到游戏结束为止\n",
    "\n",
    "        over = False\n",
    "        while not over:\n",
    "            action = self.get_action(state)\n",
    "\n",
    "            # 执行动作得到反馈\n",
    "            next_state, reward, over, info = env.step([action])\n",
    "\n",
    "            # 记录样本数据\n",
    "            states.append(state)\n",
    "            rewards.append(reward)\n",
    "            actions.append(action)\n",
    "            next_states.append(next_state)\n",
    "            overs.append(over)\n",
    "\n",
    "            # 更新状态\n",
    "            state = next_state\n",
    "\n",
    "        states = torch.FloatTensor(states).reshape(-1, 3).to(device=\"cuda\")\n",
    "        rewards = torch.FloatTensor(rewards).reshape(-1, 1).to(device=\"cuda\")\n",
    "        actions = torch.FloatTensor(actions).reshape(-1, 1).to(device=\"cuda\")\n",
    "        next_states = torch.FloatTensor(next_states).reshape(-1, 3).to(device=\"cuda\")\n",
    "        overs = torch.LongTensor(overs).reshape(-1, 1).to(device=\"cuda\")\n",
    "\n",
    "        return states, rewards, actions, next_states, overs\n",
    "\n",
    "    def test(self, play):\n",
    "        state = env.reset()\n",
    "\n",
    "        reward_sum = 0\n",
    "\n",
    "        over = False\n",
    "        while not over:\n",
    "            action = self.get_action(state)\n",
    "\n",
    "            # 执行动作\n",
    "            state, reward, over, _ = env.step([action])\n",
    "\n",
    "            reward_sum += reward\n",
    "\n",
    "            if play and random.random() < 0.2:\n",
    "                display.clear_output(wait=True)\n",
    "                show()\n",
    "        return reward_sum\n",
    "\n",
    "    def get_advantages(self, deltas):\n",
    "        advantages = []\n",
    "        s = 0.0\n",
    "\n",
    "        for dalta in deltas[::-1]:\n",
    "            s = 0.9 * 0.9 * s + dalta\n",
    "            advantages.append(s)\n",
    "\n",
    "        advantages.reverse()\n",
    "\n",
    "        return advantages\n",
    "\n",
    "    def train(self, states, rewards, actions, next_states, overs):\n",
    "        rewards = (rewards + 8) / 8\n",
    "\n",
    "        values = self.model_td(states)\n",
    "\n",
    "        targets = self.model_td(next_states).detach()\n",
    "        targets = targets * 0.98\n",
    "        targets *= 1 - overs\n",
    "        targets += rewards\n",
    "\n",
    "        deltas = (targets - values).squeeze(dim=1).tolist()\n",
    "        acvantages = self.get_advantages(deltas)\n",
    "        acvantages = torch.FloatTensor(acvantages).reshape(-1, 1).to(device=\"cuda\")\n",
    "\n",
    "        mu, std = self.model(states)\n",
    "\n",
    "        old_probs = torch.distributions.Normal(mu, std)\n",
    "        old_probs = old_probs.log_prob(actions).exp().detach()\n",
    "\n",
    "        # 每批数据训练10次\n",
    "        for _ in range(10):\n",
    "            mu, std = self.model(states)\n",
    "\n",
    "            new_probs = torch.distributions.Normal(mu, std)\n",
    "            new_probs = new_probs.log_prob(actions).exp()\n",
    "\n",
    "            ratios = new_probs / old_probs\n",
    "\n",
    "            surr1 = ratios * acvantages\n",
    "\n",
    "            surr2 = torch.clamp(ratios, 0.8, 1.2) * acvantages\n",
    "\n",
    "            loss = -torch.min(surr1,surr2)\n",
    "            loss = loss.mean()\n",
    "\n",
    "            values = self.model_td(states)\n",
    "            loss_td = self.loss_fn(values,targets)\n",
    "\n",
    "            #更新参数\n",
    "            self.optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            self.optimizer.step()\n",
    "\n",
    "            self.optimizer_td.zero_grad()\n",
    "            loss_td.backward()\n",
    "            self.optimizer_td.step()\n",
    "\n",
    "# Smoke test: build an agent and run one training update on one sampled episode,\n",
    "# then display a sampled action for a dummy state and the return of one test episode.\n",
    "teacher = PPO()\n",
    "teacher.train(*teacher.get_data())\n",
    "teacher.get_action([1,2,3]),teacher.test(play=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAakAAAGiCAYAAABd6zmYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAo8UlEQVR4nO3dfXRU9YH/8c9MJjN5nAkJZAJChP2JKI1o5SGMra2tqamlD2ps0cNRqlTFRlalx13xAVe7K1Z/rS0uD57TU3UflC7tggVBmwUN7RoDRmgDKlqblghMgsTMJIFMJjPf3x+W+RlFSMgk8834fp0z55h7vzPznSvkzb1z547DGGMEAICFnKmeAAAAn4RIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCslbJIrVixQhMnTlRWVpbKy8u1ffv2VE0FAGCplETql7/8pRYvXqz77rtPr732ms4991xVVlaqtbU1FdMBAFjKkYoLzJaXl2vmzJn613/9V0lSPB7XhAkTtGjRIt15553DPR0AgKVcw/2EPT09amho0JIlSxLLnE6nKioqVFdXd9z7RCIRRSKRxM/xeFxtbW0qKiqSw+EY8jkDAJLLGKOOjg6NGzdOTucnH9Qb9ki99957isVi8vv9fZb7/X69+eabx73PsmXLdP/99w/H9AAAw6i5uVnjx4//xPXDHqlTsWTJEi1evDjxcygUUmlpqZqbm+X1elM4MwDAqQiHw5owYYLy8/NPOG7YIzV69GhlZGSopaWlz/KWlhaVlJQc9z4ej0cej+djy71eL5ECgBHsZG/ZDPvZfW63W9OnT9eWLVsSy+LxuLZs2aJAIDDc0wEAWCwlh/sWL16s+fPna8aMGZo1a5Z++tOfqqurS9ddd10qpgMAsFRKIjV37lwdOnRIS5cuVTAY1Hnnnafnn3/+YydTAAA+3VLyOanBCofD8vl8CoVCvCcFACNQf3+Pc+0+AIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYacKS2bdumb3zjGxo3bpwcDofWr1/fZ70xRkuXLtXYsWOVnZ2tiooKvf32233GtLW1ad68efJ6vSooKNCCBQvU2dk5qBcCAEg/A45UV1eXzj33XK1YseK46x9++GEtX75cq1evVn19vXJzc1VZWanu7u7EmHnz5mnPnj2qqanRxo0btW3bNt14442n/ioAAOnJDIIks27dusTP8XjclJSUmEceeSSxrL293Xg8HvPMM88YY4x5/fXXjSSzY8eOxJjNmzcbh8Nh9u/f36/nDYVCRpIJhUKDmT4AIEX6+3s8qe9JNTU1KRgMqqKiIrHM5/OpvLxcdXV1kqS6ujoVFBRoxowZiTEVFRVyOp2qr68/7uNGIhGFw+E+NwBA+ktqpILBoCTJ7/f3We73+xPrgsGgiouL+6x3uVwqLCxMjPmoZcuWyefzJW4TJkxI5rQBAJYaEWf3LVmyRKFQKHFrbm5O9ZQAAMMgqZEqKSmRJLW0tPRZ3tLSklhXUlKi1tbWPut7e3vV1taWGPNRHo9HXq+3zw0AkP6SGqlJkyappKREW7ZsSSwLh8Oqr69XIBCQJAUC
AbW3t6uhoSExZuvWrYrH4yovL0/mdAAAI5xroHfo7OzUn/70p8TPTU1N2rVrlwoLC1VaWqrbbrtN//zP/6zJkydr0qRJuvfeezVu3DhddtllkqSzzz5bX/3qV3XDDTdo9erVikajuuWWW3TVVVdp3LhxSXthAIA0MNDTBl988UUj6WO3+fPnG2M+OA393nvvNX6/33g8HnPxxRebvXv39nmMw4cPm6uvvtrk5eUZr9drrrvuOtPR0ZH0UxcBAHbq7+9xhzHGpLCRpyQcDsvn8ykUCvH+FACMQP39PT4izu4DAHw6ESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLUGfIFZAMOnt7NTR95+W0fffVexzk45MjKUWVSknIkTlT1xohwZGameIjCkiBRgoXg0qnBDg1o3blT3u++qt6NDJhqVnE45s7Lk8nqVP22axlZVye33y+HkoAjSE5ECLNMbDiu4fr1a16+X6e3tuzIeV/zIEfUcOaLDwaA69+zRaddeq4JZs9irQlrin1+ARXq7uhRct06HNm78eKCOI7J/v979xS8U/sMfhmF2wPAjUoAljDEK1dfr0HPPKd7d3e/79bS0aP9TTyly8OAQzg5IDSIFWCLW0aF9q1cPKFDHHG1qUvDXv1Y8Gh2CmQGpQ6QAS7Ru3Kh4T88p3z/8hz/oaFNTEmcEpB6RAizR9dZbUjx+yvfvaWlRtL09eRMCLECkAAuYeFwyJtXTAKxDpAALxHt6+nU238m0btigeCSShBkBdiBSgAXikYjiSYhUR2NjUh4HsAWRAiwQj0Q+uKIEgD6IFGCBZB3uA9INkQIsEI9EiBRwHEQKsEBPS4t6w+FUTwOwDpECLJC0SDkcg38MwCJECkgjRV/6kpxud6qnASQNkQLSiCs/n70ppBUiBaSRjJwcOYgU0giRAtJIRm4ue1JIK0QKSDFjzAfX7ksCZ05OUh4HsAWRAlItHk/a9fYycnLYk0JaIVJAiplYTLFT+KLD43E4nbwnhbRCpIAUM7HYKX0bL/BpQKSAFDPxOJECPgGRAlItFlPs6NFUzwKwEpECUqy3q0tH3nkn1dMArESkgBQzkYii77036MdxZGbKkZGRhBkB9iBSQJrIPfNMZY0fn+ppAElFpIA04XS75XC5Uj0NIKmIFJAmHB6PHJmZqZ4GkFRECkgTTrdbTvakkGaIFJBCSb1uH4f7kIaIFJBisSNHkvI4Trebw31IO0QKSLHerq7kPJDDwXX7kHaIFJBi8WRFCkhDRApIsd7OzlRPAbAWkQJSLFnvSQHpiEgBqWSM3v/d71I9C8BaRApIsWh7+6Afw+FyKbOgYNCPA9iGSAFpICMvT3mf+UyqpwEkHZEC0oDD6VRGdnaqpwEkHZEC0oHTKSeRQhoiUkAacGRkyJmVleppAElHpIAUMsYk5XE43Id0RaSAFIofPSolI1ROp5wez+AfB7AMkQJSKHbkSHIipQ/2poB0w59qIIViR44k7as6gHREpIAUiidxTwpIR0QKSKFkHu4D0hGRAlKo5/BhDvcBJ0CkgBQKv/aaTDQ66MfJmThx8JMBLESkgDSQP21aqqcADIkBRWrZsmWaOXOm8vPzVVxcrMsuu0x79+7tM6a7u1vV1dUqKipSXl6eqqqq1NLS0mfMvn37NGfOHOXk5Ki4uFh33HGHent7B/9qgE+pjLy8VE8BGBIDilRtba2qq6v1yiuvqKamRtFoVJdccom6PvT117fffrs2bNigtWvXqra2VgcOHNAVV1yRWB+LxTRnzhz19PTo5Zdf1lNPPaUnn3xSS5cuTd6rAj5lMnJzUz0FYEg4zCCuy3Lo0CEVFxertrZWX/jCFxQKhTRmzBg9/fTTuvLKKyVJb775ps4++2zV1dVp9uzZ2rx5s77+9a/rwIED8vv9kqTVq1frH//xH3Xo0CG53e6TPm84HJbP51MoFJLX6z3V6QMp986yZWqvqxv0
40x55BHlTZmShBkBw6O/v8cH9Z5UKBSSJBUWFkqSGhoaFI1GVVFRkRhz1llnqbS0VHV/+4tYV1enc845JxEoSaqsrFQ4HNaePXuO+zyRSEThcLjPDRjpjDFJO/3cxeE+pKlTjlQ8Htdtt92mz33ucyorK5MkBYNBud1uFXzkG0L9fr+CwWBizIcDdWz9sXXHs2zZMvl8vsRtwoQJpzptwBqmp0cmSe/F8jUdSFenHKnq6mrt3r1ba9asSeZ8jmvJkiUKhUKJW3Nz85A/JzDU4tGo4kmKFNftQ7pyncqdbrnlFm3cuFHbtm3T+PHjE8tLSkrU09Oj9vb2PntTLS0tKikpSYzZvn17n8c7dvbfsTEf5fF45OEKz0gz8UgkKZ+RAtLZgP75ZYzRLbfconXr1mnr1q2aNGlSn/XTp09XZmamtmzZkli2d+9e7du3T4FAQJIUCATU2Nio1tbWxJiamhp5vV5NnTp1MK8FGFHiPT2KEynghAa0J1VdXa2nn35azz77rPLz8xPvIfl8PmVnZ8vn82nBggVavHixCgsL5fV6tWjRIgUCAc2ePVuSdMkll2jq1Km65ppr9PDDDysYDOqee+5RdXU1e0v4VDHRKHtSwEkMKFKrVq2SJF100UV9lj/xxBP67ne/K0l69NFH5XQ6VVVVpUgkosrKSq1cuTIxNiMjQxs3btTNN9+sQCCg3NxczZ8/Xw888MDgXgkwwkTb2tTLmarACQ3qc1KpwuekkA4Obd6sfX/7h99g5EyerMkPPCAXH+jFCDIsn5MCkHrZpaVyuk7pHCjAekQKGOEycnMlhyPV0wCGBJECRjhndrbE56SQpviTDYxwGTk5crAnhTRFpIAUMMYoWecsZWRnc7gPaYt3W4FUMEbxnp6TDDEykoykDyfo2H8f23tyZmURKaQtIgWkgInFFD9y5IRjDkci+smePdrZ1iZfZqaKs7P1f/LzNW3UKE3My1ORx6P8zEw5HA4O9yFtESkgFeJxxY4ePeGQbJdLn/f7dVpOjrp6e9Xe06M/trVpU3OzPBkZOq+wUOVjxqgqHFZBPC4nJ08gDREpIAWMMYp3d59wTK7Lpa+NH6+4MYrG4+qOxdQRjepQd7f2hkLacvCglr/xhjb/0z/pxvfe03e+8x3lcBIF0gz/9AJSIRY76Z7UMU6HQ56MDPncbo3PzdV5hYX6zqRJWl5ergfOO0+jnU7dd999+va3v63a2lpFuR4g0giRAlIg1t2tI++8c0r3dTgccjocyna5VF5crBX/9E968MEH1d7erptuuklPPvmkjpzk/S5gpOBwH5AC8Z4eRfbvH/TjeE47TUVnnqmr/+7vNGPGDD3yyCN64IEH1NLSokWLFsnr9XL4DyMae1LACObKz1dGbq4yMjI0ZcoUrVy5UjfeeKMef/xxPfjgg+rs7Ez1FIFBYU8KGMGcbrecmZmSPjgM6PF4dPvtt8vtduuxxx7TmDFjtGjRIr6rDSMWkQJGMIfbLcffInVMXl6eqqur1dHRoRUrVui0007Tt7/9bbm4UjpGIA73ASOY0+2W0+3+2PLc3FzdddddmjVrlh5++GHt3r07aZdhAoYTkQKGmTFGSlIwnJmZH9uTkj449Jedna1/+Id/kDFGK1asUCwWS8pzAsOJSAEpEO/nZ6ROxpGZKccnHMZzOByaNm2avve97+n5559XbW0te1MYcYgUkAK9XV1Je6wTnWLucrk0d+5cTZkyRStWrFA4HE7a8wLDgUgBKRBLYqROxOFwaMyYMbrmmmu0a9cu7dixg70pjChECkiB2DB/fumb3/ymsrOz9T//8z/qPsk1AwGbECkgCYwxisVi/d5LSebhvv7Iy8vTZZddppdeekmHDx8e1ucGBoNIAUlgjNGOHTsUDAb7NT5UX5+cJ+7nJY9cLpcuvPBC/fWvf9X+/fs55IcRg0gBSRCNRvXjH/+432fQRQ4eHPRzOnNy5Jsxo19jHQ6HTj/9dJ122mmqT1Yg
gWHAR9CBJPjTn/6k3/72tyosLNScOXOUn59/yo+1v6tLO9va1BGNakxWlgJjxij3eJ+FcjrlGsDzjBkzRhMnTlRDQ4OMMVx4FiMCkQIGyRij5557TkePHtW2bdu0f/9+TZkyZcARMMaoqbNT9+3cqb90dqo7FpM3M1Nlo0bp/86cqcyPfPOuw+mUMyur34/v9Xo1duxYNTQ0DGheQCpxuA8YpPfff1+bN29WLBbT22+/fcofmv1zZ6du+N//1RuhkI7GYjKSQtGo/re1VbfW1+vwR8/KczqVkZPT78fPzMyU3+/v9/tmgA2IFDBIr732mv785z9L+mBvaM2aNaf07bg/3bNHoU+43/b33lPNgQN9ljmcTmVkZ/f78R0Oh/Lz89Xb2zvguQGpQqSAQYhGo9q+fbtCoZBcLpe8Xq/27Nmj3bt3n/B+x7ve3oA5nXIOIFKS5Ha75XTy1x4jB39agUFoa2vTwYMHtXz5co0dO1YLFy7U3//936umpuaEF3Q9Y+nSQT+3w+GQc4DfExWNRhWPxwf93MBwIVLAIGRlZWnBggW68sor5fV6lZmZqR/84Af6yle+8omH/BwOx3EvCjtnwgRlfsLJFhPz8jStsHBQczXGqKuri7P6MKIQKWAQvF6vpk2bljgp4fDhw4rH4zr//PNP+G24mT6fCr/0pT7LKseN032f/ayyMjISfzEzHA4VeTz68cyZmlpQ0Ge8v6pqQHONxWI6fPiwRo0aNaD7AanEKejAIDgcDjkcDsXjcZ155plqbm5WZ2encnNzT3g/Z3a2RgUCCr36qmIdHYnHqhw3TuNzcrTx3Xd1uLtbE/PyNHfSJBV9JHiesWM1KhAY0F5RR0eHDhw4oDPPPHPgLxRIESIFJIHT6VRZWZm2bdumUCgkv99/wvEOh0Pez35WYy69VC2//rXM396/cjgcKhs1SmUn2NtxjRqlcddcI5fXO6A5trW1qbm5WZWVlRzyw4jB4T4gCRwOh8rKyvTee+/p3Xff7dfnpJwej/yXX67CL3/5E7+48KMy8vI0du5cFcyaJUdGRr/nZ4zRwYMH1dTUpNmzZ/f7fkCqESkgSSZMmKDJkyerpqam32fQZeTkaPz118tfVSV3cfEnjnNkZCh74kRNuOEGjbn0Ujnd7gHNLR6Pq76+XoWFhTrjjDMGdF8glTjcBySBw+GQ3+/XzJkztWnTJt19993Ky8vr1/1cubkae+WV8k6bpvdfflmde/YoEgwqHokoIy9PWePHyzdjhnwzZii7tPSUDtX19PRow4YNuuCCC1RUVMThPowYRApIkqysLFVUVGj9+vV6/vnnVVVV1e8YOD0e5ZWVKffMMxU7elSmt1cmHpcjI0POzEw5c3Lk7OchweOpq6vTO++8o+9///v9iidgCyIFJInD4dDnP/95feYzn9ETTzyhiy66SKNHjx7Q/R0ez4A/oHsynZ2dWrlypc4++2x94QtfYC8KIwrvSQFJlJ+fr+9973vauXOnXnjhhRNedWI4GGO0ZcsWbd++XVddddVJzzoEbEOkgCSrqKjQhRdeqJUrV6qpqSll34JrjNFf/vIXrVq1SlOmTNHcuXPZi8KIQ6SAJMvOztbtt9+utrY2/eQnP1H3R79iY5hEo1GtXr1ab731lu666y7lDOBrPQBbECkgyRwOh8477zzddNNN2rBhg/793/9dvb29w7pHFY/H9atf/UrPPPOMrrvuOpWXl7MXhRGJEyeAIZCVlaWbbrpJTU1N+tGPfqSsrCxdddVVyszMHPJY9PT06Nlnn9V9992nyspKLVq0SNkD/EoPwBZEChgi2dnZWrp0qQ4fPqwf/vCHisfjmjdvnjKT8V1SnyAWi+nXv/617r//fpWVlemHP/yhCj5yYVpgJOFwHzCECgsL9S//8i+64IILdP/992v58uVqb29P+qG/Y1/D8fjjj+vuu+9WWVmZfvSjH3E2H0Y8IgUMIYfDodLSUj322GOaO3euHn30UV1/
/fV69dVXk3ZCRU9Pj3bu3KmFCxfqwQcf1Ne//nWtWrVKkydP5n0ojHgOk6rzYwchHA7L5/MpFArJO8ArQQOpcvToUT3zzDNatWqVOjs7ddVVV+myyy5TWVmZMgZwsdhjYrGY3n77bf3mN7/Rf/zHf0iSbrrpJn33u9896VeFAKnW39/jRAoYRtFoVE1NTVq9erX+67/+S16vVxdddJHmzp2rc889V9nZ2crIyJDT6Ux8V5UxRsYYxeNxxWIxdXd36/XXX9cvf/lLbd26VYcPH9bll1+uhQsX6swzz5R7gBefBVKBSAGWMsYoFovpjTfe0L/927/ppZde0rvvvqvc3FyVl5errKxMpaWlKioqUmZmpnp7e/X++++rublZjY2Nqq+vVygU0vjx43XhhRfq2muvVVlZmVwuF4f3MGIQKWAEiMfjeuedd9TQ0KDGxkY1NjaqqalJhw4dUldXl+LxuDIyMpSTk6PRo0fr9NNPV1lZmc455xzNmDFDkydPPqVDhUCqESlgBDHGKBKJqKOjQ11dXYpEIopGozLGfPB1Hi6XPB6PcnNz5fV65fF42GvCiNbf3+N8TgqwgMPhUFZWlrKysjRmzJhUTwewBqegAwCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYaUKRWrVqladOmyev1yuv1KhAIaPPmzYn13d3dqq6uVlFRkfLy8lRVVaWWlpY+j7Fv3z7NmTNHOTk5Ki4u1h133KHe3t7kvBoAQFoZUKTGjx+vhx56SA0NDXr11Vf15S9/Wd/61re0Z88eSdLtt9+uDRs2aO3ataqtrdWBAwd0xRVXJO4fi8U0Z84c9fT06OWXX9ZTTz2lJ598UkuXLk3uqwIApAczSKNGjTI///nPTXt7u8nMzDRr165NrHvjjTeMJFNXV2eMMWbTpk3G6XSaYDCYGLNq1Srj9XpNJBLp93OGQiEjyYRCocFOHwCQAv39PX7K70nFYjGtWbNGXV1dCgQCamhoUDQaVUVFRWLMWWedpdLSUtXV1UmS6urqdM4558jv9yfGVFZWKhwOJ/bGjicSiSgcDve5AQDS34Aj1djYqLy8PHk8Hi1cuFDr1q3T1KlTFQwG5Xa7VVBQ0Ge83+9XMBiUJAWDwT6BOrb+2LpPsmzZMvl8vsRtwoQJA502AGAEGnCkpkyZol27dqm+vl4333yz5s+fr9dff30o5pawZMkShUKhxK25uXlInw8AYIcBf+mh2+3WGWecIUmaPn26duzYoZ/97GeaO3euenp61N7e3mdvqqWlRSUlJZKkkpISbd++vc/jHTv779iY4/F4PPJ4PAOdKgBghBv056Ti8bgikYimT5+uzMxMbdmyJbFu79692rdvnwKBgCQpEAiosbFRra2tiTE1NTXyer2aOnXqYKcCAEgzA9qTWrJkiS699FKVlpaqo6NDTz/9tF566SW98MIL8vl8WrBggRYvXqzCwkJ5vV4tWrRIgUBAs2fPliRdcsklmjp1qq655ho9/PDDCgaDuueee1RdXc2eEgDgYwYUqdbWVl177bU6ePCgfD6fpk2bphdeeEFf+cpXJEmPPvqonE6nqqqqFIlEVFlZqZUrVybun5GRoY0bN+rmm29WIBBQbm6u5s+frwceeCC5rwoAkBYcxhiT6kkMVDgcls/nUygUktfrTfV0AAAD1N/f41y7DwBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCs
RaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1BhWphx56SA6HQ7fddltiWXd3t6qrq1VUVKS8vDxVVVWppaWlz/327dunOXPmKCcnR8XFxbrjjjvU29s7mKkAANLQKUdqx44devzxxzVt2rQ+y2+//XZt2LBBa9euVW1trQ4cOKArrrgisT4Wi2nOnDnq6enRyy+/rKeeekpPPvmkli5deuqvAgCQnswp6OjoMJMnTzY1NTXmi1/8orn11luNMca0t7ebzMxMs3bt2sTYN954w0gydXV1xhhjNm3aZJxOpwkGg4kxq1atMl6v10QikX49fygUMpJMKBQ6lekDAFKsv7/HT2lPqrq6WnPmzFFFRUWf5Q0NDYpGo32Wn3XWWSotLVVdXZ0kqa6uTuecc478fn9iTGVlpcLhsPbs2XPc54tEIgqHw31uAID05xroHdasWaPXXntNO3bs+Ni6YDAot9utgoKCPsv9fr+CwWBizIcDdWz9sXXHs2zZMt1///0DnSoAYIQb0J5Uc3Ozbr31Vv3nf/6nsrKyhmpOH7NkyRKFQqHErbm5edieGwCQOgOKVENDg1pbW3X++efL5XLJ5XKptrZWy5cvl8vlkt/vV09Pj9rb2/vcr6WlRSUlJZKkkpKSj53td+znY2M+yuPxyOv19rkBANLfgCJ18cUXq7GxUbt27UrcZsyYoXnz5iX+OzMzU1u2bEncZ+/evdq3b58CgYAkKRAIqLGxUa2trYkxNTU18nq9mjp1apJeFgAgHQzoPan8/HyVlZX1WZabm6uioqLE8gULFmjx4sUqLCyU1+vVokWLFAgENHv2bEnSJZdcoqlTp+qaa67Rww8/rGAwqHvuuUfV1dXyeDxJelkAgHQw4BMnTubRRx+V0+lUVVWVIpGIKisrtXLlysT6jIwMbdy4UTfffLMCgYByc3M1f/58PfDAA8meCgBghHMYY0yqJzFQ4XBYPp9PoVCI96cAYATq7+9xrt0HALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALCWK9UTOBXGGElSOBxO8UwAAKfi2O/vY7/PP8mIjNThw4clSRMmTEjxTAAAg9HR0SGfz/eJ60dkpAoLCyVJ+/btO+GL+7QLh8OaMGGCmpub5fV6Uz0da7Gd+oft1D9sp/4xxqijo0Pjxo074bgRGSmn84O30nw+H38I+sHr9bKd+oHt1D9sp/5hO51cf3YyOHECAGAtIgUAsNaIjJTH49F9990nj8eT6qlYje3UP2yn/mE79Q/bKbkc5mTn/wEAkCIjck8KAPDpQKQAANYiUgAAaxEpAIC1RmSkVqxYoYkTJyorK0vl5eXavn17qqc0rLZt26ZvfOMbGjdunBwOh9avX99nvTFGS5cu1dixY5Wdna2Kigq9/fbbfca0tbVp3rx58nq9Kigo0IIFC9TZ2TmMr2JoLVu2TDNnzlR+fr6Ki4t12WWXae/evX3GdHd3q7q6WkVFRcrLy1NVVZVaWlr6jNm3b5/mzJmjnJwcFRcX64477lBv
b+9wvpQhtWrVKk2bNi3xwdNAIKDNmzcn1rONju+hhx6Sw+HQbbfdlljGthoiZoRZs2aNcbvd5he/+IXZs2ePueGGG0xBQYFpaWlJ9dSGzaZNm8zdd99t/vu//9tIMuvWreuz/qGHHjI+n8+sX7/e/OEPfzDf/OY3zaRJk8zRo0cTY7761a+ac88917zyyivmd7/7nTnjjDPM1VdfPcyvZOhUVlaaJ554wuzevdvs2rXLfO1rXzOlpaWms7MzMWbhwoVmwoQJZsuWLebVV181s2fPNhdccEFifW9vrykrKzMVFRVm586dZtOmTWb06NFmyZIlqXhJQ+I3v/mNee6558xbb71l9u7da+666y6TmZlpdu/ebYxhGx3P9u3bzcSJE820adPMrbfemljOthoaIy5Ss2bNMtXV1YmfY7GYGTdunFm2bFkKZ5U6H41UPB43JSUl5pFHHkksa29vNx6PxzzzzDPGGGNef/11I8ns2LEjMWbz5s3G4XCY/fv3D9vch1Nra6uRZGpra40xH2yTzMxMs3bt2sSYN954w0gydXV1xpgP/jHgdDpNMBhMjFm1apXxer0mEokM7wsYRqNGjTI///nP2UbH0dHRYSZPnmxqamrMF7/4xUSk2FZDZ0Qd7uvp6VFDQ4MqKioSy5xOpyoqKlRXV5fCmdmjqalJwWCwzzby+XwqLy9PbKO6ujoVFBRoxowZiTEVFRVyOp2qr68f9jkPh1AoJOn/X5y4oaFB0Wi0z3Y666yzVFpa2mc7nXPOOfL7/YkxlZWVCofD2rNnzzDOfnjEYjGtWbNGXV1dCgQCbKPjqK6u1pw5c/psE4k/T0NpRF1g9r333lMsFuvzP1mS/H6/3nzzzRTNyi7BYFCSjruNjq0LBoMqLi7us97lcqmwsDAxJp3E43Hddttt+tznPqeysjJJH2wDt9utgoKCPmM/up2Otx2PrUsXjY2NCgQC6u7uVl5entatW6epU6dq165dbKMPWbNmjV577TXt2LHjY+v48zR0RlSkgFNRXV2t3bt36/e//32qp2KlKVOmaNeuXQqFQvrVr36l+fPnq7a2NtXTskpzc7NuvfVW1dTUKCsrK9XT+VQZUYf7Ro8erYyMjI+dMdPS0qKSkpIUzcoux7bDibZRSUmJWltb+6zv7e1VW1tb2m3HW265RRs3btSLL76o8ePHJ5aXlJSop6dH7e3tfcZ/dDsdbzseW5cu3G63zjjjDE2fPl3Lli3Tueeeq5/97Gdsow9paGhQa2urzj//fLlcLrlcLtXW1mr58uVyuVzy+/1sqyEyoiLldrs1ffp0bdmyJbEsHo9ry5YtCgQCKZyZPSZNmqSSkpI+2ygcDqu+vj6xjQKBgNrb29XQ0JAYs3XrVsXjcZWXlw/7nIeCMUa33HKL1q1bp61bt2rSpEl91k+fPl2ZmZl9ttPevXu1b9++PtupsbGxT9Bramrk9Xo1derU4XkhKRCPxxWJRNhGH3LxxRersbFRu3btStxmzJihefPmJf6bbTVEUn3mxkCtWbPGeDwe8+STT5rXX3/d3HjjjaagoKDPGTPprqOjw+zcudPs3LnTSDI/+clPzM6dO81f//pXY8wHp6AXFBSYZ5991vzxj3803/rWt457CvpnP/tZU19fb37/+9+byZMnp9Up6DfffLPx+XzmpZdeMgcPHkzcjhw5khizcOFCU1paarZu3WpeffVVEwgETCAQSKw/dsrwJZdcYnbt2mWef/55M2bMmLQ6ZfjOO+80tbW1pqmpyfzxj380d955p3E4HOa3v/2tMYZtdCIfPrvPGLbVUBlxkTLGmMcee8yUlpYat9ttZs2aZV555ZVUT2lYvfjii0bSx27z5883xnxwGvq9995r/H6/8Xg85uKLLzZ79+7t8xiHDx82V199tcnLyzNer9dcd911pqOjIwWvZmgcb/tIMk888URizNGjR833v/99M2rUKJOTk2Muv/xyc/DgwT6P85e//MVceuml
Jjs724wePdr84Ac/MNFodJhfzdC5/vrrzemnn27cbrcZM2aMufjiixOBMoZtdCIfjRTbamjwVR0AAGuNqPekAACfLkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBY6/8Bry5JqNu9Tv0AAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "-1.8403543992077669"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train the teacher on the batches it collects itself. Every 500 iterations,\n",
    "# report the mean of 10 teacher.test(play=False) results.\n",
    "for i in range(6000):\n",
    "    teacher.train(*teacher.get_data())\n",
    "\n",
    "    # Only iterations 0, 500, 1000, ... produce a progress report.\n",
    "    if i % 500 != 0:\n",
    "        continue\n",
    "\n",
    "    test_result = sum(teacher.test(play=False) for _ in range(10)) / 10\n",
    "    print(f'第 {i} 轮：{test_result}')\n",
    "\n",
    "# Final call with play=True; its return value becomes the cell output.\n",
    "teacher.test(play=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<__main__.PPO at 0x16200f92340>"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Create the student: a newly constructed PPO agent that the imitation loop trains.\n",
    "student = PPO()\n",
    "\n",
    "# Bare last expression so the notebook displays the object's repr.\n",
    "student"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.70563424,  0.7085763 , -0.9421133 ], dtype=float32)"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reset the wrapped environment; the returned initial state is shown as the cell output.\n",
    "env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.4535],\n",
       "        [0.4057]], device='cuda:0', grad_fn=<SigmoidBackward0>)"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 定义鉴别网络，它的任务是鉴别一批数据是来自teacher 还是student\n",
    "from typing import Any\n",
    "\n",
    "\n",
    "class Discriminator(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.sequential = torch.nn.Sequential(\n",
    "            torch.nn.Linear(4, 128),  # 将[state action] [3+1]\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Linear(128, 1),\n",
    "            torch.nn.Sigmoid(),\n",
    "        )  # 范围0~1\n",
    "\n",
    "    def forward(self, states, actions):\n",
    "        cat = torch.cat([states, actions], dim=1)\n",
    "        return self.sequential(cat)\n",
    "\n",
    "\n",
    "# Instantiate the discriminator on the GPU (the device is hard-coded to 'cuda').\n",
    "discriminator = Discriminator().to(device='cuda')\n",
    "# Smoke test: 2 random states (3 values each) and 2 random actions (1 value each).\n",
    "discriminator(torch.randn(2, 3).to(device='cuda'), torch.rand(2, 1).to(device='cuda'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 -1309.2257431874823\n",
      "500 -1324.0444243828188\n",
      "1000 -1049.550021565065\n",
      "1500 -644.5256335601496\n"
     ]
    }
   ],
   "source": [
    "# 模仿学习\n",
    "def copy_learn(n_iters=2000, lr=1e-4, eval_every=500, n_eval_episodes=10):\n",
    "    \"\"\"Adversarial imitation learning: train `student` to mimic `teacher`.\n",
    "\n",
    "    Each iteration:\n",
    "      1. collects one batch from the teacher and one from the student,\n",
    "      2. updates the global `discriminator` toward 0 on teacher pairs and\n",
    "         toward 1 on student pairs,\n",
    "      3. calls student.train() with -log(D(state, action)) in place of the\n",
    "         environment reward, so the student is rewarded when the\n",
    "         discriminator cannot tell it apart from the teacher.\n",
    "\n",
    "    Args:\n",
    "        n_iters: number of training iterations.\n",
    "        lr: Adam learning rate used for the discriminator.\n",
    "        eval_every: print the student's mean test score every this many\n",
    "            iterations; 0 or None disables the reports.\n",
    "        n_eval_episodes: number of student.test(play=False) calls averaged\n",
    "            per report (must be >= 1).\n",
    "\n",
    "    Returns:\n",
    "        None. Progress is printed as `iteration mean_score`.\n",
    "    \"\"\"\n",
    "    optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)\n",
    "    bce_loss = torch.nn.BCELoss()\n",
    "\n",
    "    for i in range(n_iters):\n",
    "        # Batch of teacher data from the already-trained model; no gradients needed.\n",
    "        with torch.no_grad():\n",
    "            teacher_states, _, teacher_actions, _, _ = teacher.get_data()\n",
    "\n",
    "        # One game of student data; the environment reward (2nd item) is discarded.\n",
    "        states, _, actions, next_states, overs = student.get_data()\n",
    "\n",
    "        # Let the discriminator judge both batches.\n",
    "        prob_teacher = discriminator(teacher_states, teacher_actions)\n",
    "        prob_student = discriminator(states, actions)\n",
    "\n",
    "        # Binary classification loss: teacher is labelled 0, student is labelled 1.\n",
    "        loss_teacher = bce_loss(prob_teacher, torch.zeros_like(prob_teacher))\n",
    "        loss_student = bce_loss(prob_student, torch.ones_like(prob_student))\n",
    "        loss = loss_teacher + loss_student\n",
    "\n",
    "        # Update the discriminator.\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Reward = -log(probability that the sample came from the student).\n",
    "        # The discriminator pushes student samples toward 1, so the student earns\n",
    "        # more reward the less distinguishable it is (adversarial-network idea).\n",
    "        # The log is clamped at -100, the same floor BCELoss uses, so a saturated\n",
    "        # sigmoid (probability 0) cannot produce an infinite reward.\n",
    "        rewards = -prob_student.detach().log().clamp(min=-100.0)\n",
    "\n",
    "        # Cancel the reward offset inside the model -- presumably PPO.train\n",
    "        # rescales rewards as (reward + 8) / 8; confirm against the PPO class.\n",
    "        rewards = rewards * 8 - 8\n",
    "\n",
    "        # Update the student with PPO's own training step.\n",
    "        student.train(states, rewards, actions, next_states, overs)\n",
    "\n",
    "        if eval_every and i % eval_every == 0:\n",
    "            test_result = sum([student.test(play=False)\n",
    "                               for _ in range(n_eval_episodes)]) / n_eval_episodes\n",
    "\n",
    "            print(i, test_result)\n",
    "\n",
    "\n",
    "# Run the imitation-learning loop; it periodically prints the student's mean test score.\n",
    "copy_learn()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAakAAAGiCAYAAABd6zmYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAj6klEQVR4nO3df3DU9YH/8ddns9mFJOyGBLMxkEB6ctJ8+WEFhK1X7ZWUtE17tdIbzy+1nMfYr15wRG6ck57i6NxMHP1+z9Y7ip3ptDpzVRx64ikHtbmAsdYYMBAF1By9YpMCm/DD7CaBbJLd9/cPZc9VsAnZ7L43eT5mdkY+n/d+8v58lDzd/Xz2s44xxggAAAu5Mj0BAAAuhkgBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKyVsUht3rxZc+bM0ZQpU7Rs2TLt3bs3U1MBAFgqI5F69tlntWHDBj3wwAPav3+/Fi1apJqaGnV3d2diOgAASzmZuMHssmXLtHTpUv3Lv/yLJCkej6u8vFx33nmn7r333nRPBwBgKXe6f+Dg4KBaW1u1cePGxDKXy6Xq6mo1Nzdf8DnRaFTRaDTx53g8rjNnzqi4uFiO44z7nAEAqWWMUW9vr8rKyuRyXfxNvbRH6tSpU4rFYgoEAknLA4GA3n333Qs+p76+Xg8++GA6pgcASKPOzk7NmjXrouvTHqlLsXHjRm3YsCHx53A4rIqKCnV2dsrn82VwZgCASxGJRFReXq5p06Z96ri0R2rGjBnKyclRV1dX0vKuri6VlpZe8Dler1der/cTy30+H5ECgCz2x07ZpP3qPo/Ho8WLF6uxsTGxLB6Pq7GxUcFgMN3TAQBYLCNv923YsEFr1qzRkiVLdM011+gHP/iB+vv7deutt2ZiOgAAS2UkUjfddJNOnjypTZs2KRQK6aqrrtIvf/nLT1xMAQCY3DLyOamxikQi8vv9CofDnJMCgCw00t/j3LsPAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLVGHalXXnlF3/jGN1RWVibHcfT8888nrTfGaNOmTbr88ss1depUVVdX68iRI0ljzpw5o9WrV8vn86mwsFBr165VX1/fmHYEADDxjDpS/f39WrRokTZv3nzB9Y888ogef/xxPfHEE2ppaVF+fr5qamo0MDCQGLN69WodPnxYDQ0N2rFjh1555RV973vfu/S9AABMTGYMJJnt27cn/hyPx01paal59NFHE8t6enqM1+s1zzzzjDHGmLfffttIMvv27UuM2bVrl3Ecxxw7dmxEPzccDhtJJhwOj2X6AIAMGenv8ZSekzp69KhCoZCqq6sTy/x+v5YtW6bm5mZJUnNzswoLC7VkyZLEmOrqarlcLrW0tFxwu9FoVJFIJOkBAJj4UhqpUCgkSQoEAknLA4FAYl0oFFJJSUnSerfbraKiosSYj6uvr5ff7088ysvLUzltAIClsuLqvo0bNyocDicenZ2dmZ4SACANUhqp0tJSSVJXV1fS8q6ursS60tJSdXd3J60fHh7WmTNnEmM+zuv1yufzJT0AABNfSiNVWVmp0tJSNTY2JpZFIhG1tLQo
GAxKkoLBoHp6etTa2poYs3v3bsXjcS1btiyV0wEAZDn3aJ/Q19en3/72t4k/Hz16VG1tbSoqKlJFRYXWr1+vf/zHf9TcuXNVWVmp+++/X2VlZbrhhhskSZ/97Gf1la98RbfddpueeOIJDQ0Nad26dfqrv/orlZWVpWzHAAATwGgvG9yzZ4+R9InHmjVrjDEfXIZ+//33m0AgYLxer1mxYoVpb29P2sbp06fNzTffbAoKCozP5zO33nqr6e3tTfmliwAAO43097hjjDEZbOQliUQi8vv9CofDnJ8CgCw00t/jWXF1HwBgciJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLVGfRd0ACNj4nHFo1H1vvWWIm1tGvjDHxQfHJTb71fBlVeqcPlyeQMBKSdHjuNkerqAlYgUMA5i0ah633xToW3b1H/kiGTMB48PhVtaFHruORX/+Z/rstpaTbn88gzOFrAXkQJSzMRiOvWrX6nr3/5NQ2fOXGSQUay3V90vvKCBY8c066//WlNnz07vRIEswDkpIIVMPK7TTU068fTTFw/Ux0QOHNAfnnpKAydOKAu/OQcYV0QKSKG+d97RiZ//XLH+/pE/KR5X5I03FPrFL2SGh8dvckAWIlJAisSjUb3/6qsaPHnykp5/urFR0a6uFM8KyG5ECkiRc++9p9O7d1/6BuJx/eGnP03dhIAJgEgBKRIfHlb83LkxbWO4p0fxoaEUzQjIfkQKsEh8cFDD4XCmpwFYg0gBFomdPct5KeAjiBRgkaFTpxRpa8v0NABrECkAgLWIFJAiLrdbrqlTx74hY/hQL/AhIgWkiOfyy+W7+uoxbyd29qwMV/gBkogUkDIur1dun2/M2xkOhxWPRlMwIyD7ESkgRVweT0oiFQ2FNNzXl4IZAdmPSAEp4jiOUvGtUGePHBnxzWmBiY5IAQCsRaSAFHL7/XI8njFvxwwPc4UfICIFpNS0RYvkmTFjzNu51DupAxMNkQJSyO3zyZWCV1IDx49L8XgKZgRkNyIFpJB72rSUROrkiy/KxGIpmBGQ3YgUkEKOyyU5Y7/Gj/NRwAeIFADAWkQKSLEpM2eOfSPG8IFeQEQKSLnp110nucb4V8sYrvADRKSAlPOUlIz5vJQxRgMdHSmaEZC9iBSQYp7p08d+e6RYTGdefjkFswGyG5ECUm2sb/UBSOBvE2ApY4wMH+jFJEekgBRzXC5NnTNnzNuJR6Ma7u0d+4SALEakgBRz3G75ly4d83bi0aiGI5EUzAjIXkQKSDXHkScQGPNmBjo7Fd63LwUTArIXkQJSzXGU6/ePfTvGcP8+THpECkgxx3FSeoUf9/HDZEakAIvFBwf5yg5MakQKGAe5fn9KzksNv//+B6ECJikiBYyD3OLilFyGPnjmjOLR6NgnBGQpIgWMg5ypU+X2+ca8nUhrqwZPnUrBjIDsRKSAceB4PMqZOnXsG+KiCUxyRAoYBym9ws8YrvDDpEWkAMsNnj6d6SkAGUOkgHGSV1mpnLy8MW9nsLs7BbMBshORAsZJ3p/8iXIKCsa8ne4XXuCzUpi0iBQwTnKnT5eTmzvm7Qz39aVgNkB2IlLAOMnJz5fjdmd6GkBWI1LAOHFcrrF/jbwkGcNdJzBpESnAdsZo8OTJTM8CyAgiBYwj35IlY96GiceJFCYtIgWMo4KqqjFvwwwO6lRDQwpmA2QfIgWMI28gIDljPzNlhoZSMBsg+xApYBy5U/ENvR/i1kiYjEYVqfr6ei1dulTTpk1TSUmJbrjhBrW3tyeNGRgYUF1dnYqLi1VQUKBVq1apq6sraUxHR4dqa2uVl5enkpIS3XPPPRoeHh773gCWcVJ0/754NKr4uXMp2RaQTUb1N6ipqUl1dXV6/fXX1dDQoKGhIa1cuVL9/f2JMXfffbdefPFFbdu2TU1NTTp+/LhuvPHGxPpYLKba2loNDg7qtdde01NPPaUnn3xS
mzZtSt1eARNM7Nw5Dff2ZnoaQNo5ZgzvIZw8eVIlJSVqamrSddddp3A4rMsuu0xPP/20vv3tb0uS3n33XX32s59Vc3Ozli9frl27dunrX/+6jh8/rsCH31z6xBNP6O///u918uRJeTyeP/pzI5GI/H6/wuGwfCn4zh5gvMQGBvS7Rx5R5I03xrQd78yZqtywQflz56ZoZkBmjfT3+JjeiwiHw5KkoqIiSVJra6uGhoZUXV2dGDNv3jxVVFSoublZktTc3KwFCxYkAiVJNTU1ikQiOnz48AV/TjQaVSQSSXoA2cBxu5X3mc+MeTvRY8fU/7G31oHJ4JIjFY/HtX79el177bWaP3++JCkUCsnj8aiwsDBpbCAQUCgUSoz5aKDOrz+/7kLq6+vl9/sTj/Ly8kudNpBWjsul3OLi1GyMCycwCV1ypOrq6nTo0CFt3bo1lfO5oI0bNyocDicenZ2d4/4zgZRwHLmnTUvJpkwsxhV+mHQu6e6X69at044dO/TKK69o1qxZieWlpaUaHBxUT09P0quprq4ulZaWJsbs3bs3aXvnr/47P+bjvF6vvF7vpUwVyCjHcVJ2hd/Q++/LDA3JGcF5W2CiGNXfHmOM1q1bp+3bt2v37t2qrKxMWr948WLl5uaqsbExsay9vV0dHR0KBoOSpGAwqIMHD6r7I1/k1tDQIJ/Pp6oUfDofmKiG3n9fcT7Ui0lmVJGqq6vTv/7rv+rpp5/WtGnTFAqFFAqFdO7Dz2/4/X6tXbtWGzZs0J49e9Ta2qpbb71VwWBQy5cvlyStXLlSVVVVuuWWW/Tmm2/qpZde0n333ae6ujpeLWFCmjJzpqam4OKJnpYWDX94sRIwWYwqUlu2bFE4HNYXv/hFXX755YnHs88+mxjz2GOP6etf/7pWrVql6667TqWlpXruuecS63NycrRjxw7l5OQoGAzqO9/5jr773e/qoYceSt1eARZx+3zK/fAK2LGInzsnE4ulYEZA9hjVOamRnLSdMmWKNm/erM2bN190zOzZs7Vz587R/Ggga7mmTpU7Pz/T0wCyEvfuA8aZy+uVa+rUlGxruK+PK/wwqRApYJw5KbgL+nmDF/ksITBRESkgHVIUqujHbtYMTHRECkiDwqVLlVNQMObtnObLDzHJECkgDTwlJXKl4EO48cHBFMwGyB5ECkgD9/TpctyXdIOXT4rHU7MdIAsQKSAN3AUFcnJyxrwdE4tp6P33UzAjIDsQKSANUnWFn4nHNXjyZEq2BWQDIgWkSwpuNBsfGND7r72WgskA2YFIAWkyo6Zm7BuJxzV05szYtwNkCSIFpIn3Il9FA+DiiBSQJt6SkpRsxwwP85UdmDSIFJAmbr8/JduJ9fcr1teXkm0BtiNSQJaJ9fdrmEhhkiBSQJo4jpOSD/Se6+jQ2d/9LgUzAuxHpIA0ceXlqfDaa8e8HTM0JBONpmBGgP2IFJAmTk6OPCn4hl5JMhrZl5AC2Y5IAWniuFzKLS5OybaGIxHu4YdJgUgB6eJyyZ2Cr+uQpKGTJxUfHk7JtgCbESkgTUZz/z5jjDr6+vTjd9/VU7/9rfafPq2BWCzxFl94/37Fz54dr6kC1kjRdwcAGAknN1eO2y0zgldBpXl5+svKSvUPD2v/6dP6dVeXrgsE9L8KC6WuLj7Qi0mBSAFpNHXOHE2trNTZI0c+dZzjOPI4joq8XhV5vSrPz9eZaFQvHTumE+fO6QuBQJpmDGQWb/cBaZSTn6+cvLxLem6R16sbZ8+WjNFvuroUHRjgCj9MeEQKSCN3fr5yxnDxhMfl0pfKynQ6GtVbra0pnBlgJyIFpJHL65XL4/nUMcf6+7Wjs1PP/O53+s/jx9X/kXNPjuNoSk6OVs6cqX979lkNcV4KExznpIAMMMbISHL0P1f9GWN0tK9PDxw4oPf6+jQQi8mXm6v506fr/y5dqtyPfGniZVOmyNfVpUOHDunqq6/OzE4AacArKSDN3NOmKRKLqfH4ccU+
ck7pd319uu03v9E74bDOxWIyksJDQ/pNd7fuamnR6YGBpO1c5fXqNb6lFxMckQLSzL94sd4bGtJzv/+9QufOJZb/4PBhhS/y9t3eU6fUcPx40rICt1v9/f3jOlcg04gUkGbuoiLtPXNG3QMDOnb27CVfoec4jnJzc1M8O8AunJMC0sw7Y4aunzlT3nPnVJmfnzg3NVrGGA1zayRMcLySAtLMlZenq7/9beW73XK5XHJ9eOFEbXm5ci9y66Q5BQVa+LE7qEfjcblT8P1UgM2IFJBmjuNo9re+pVnXX6+uc+cSb/fVlJXpgc99TlNychJ/MXMcR8Ver/7f0qWqKixM2k7XlVdqwYIF6Z08kGb8bxiQAd6CAs2rrlbL/v260hi5HUeO46imrEyz8vK04w9/0OmBAc0pKNBNlZUq9nqTnj84bZoO9fXpf//Zn2VoD4D0IFJABjiOoyVf+IL2vPCCjr33nmbn5yeWz58+XfOnT7/oc11+v/b5fPry176mKVOmpGvKQEbwdh+QIb6iIq39/vfVIOlUNDqiq/ycvDz996xZildW6gvXXTeqr/8AshGRAjJo9mc+o//z6KPanZendwcGkj7c+1FOTo5iJSU6PGeOjhYV6S9vvlkFKfoCRcBmRArIIMdxtGDRIn3nwQcVuuYaPdfbq/c8HsU8HslxlDNtmpzKSh0uK9NOj0exP/1TrVmzRjNnzuRVFCYFzkkBGZaTk6P5Cxao8jOf0X8fOaL/3LlTu958U7GcHCkaVX5fn6790pe05vOf18yZM/kALyYVIgVYwHEcFRQUaOFVV2nhVVd96jhgMiFSgEWIEJCMc1IAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArDWqSG3ZskULFy6Uz+eTz+dTMBjUrl27EusHBgZUV1en4uJiFRQUaNWqVerq6kraRkdHh2pra5WXl6eSkhLdc889Gh4eTs3eAAAmlFFFatasWXr44YfV2tqqN954Q1/60pf0zW9+U4cPH5Yk3X333XrxxRe1bds2NTU16fjx47rxxhsTz4/FYqqtrdXg4KBee+01PfXUU3ryySe1adOm1O4VAGBiMGM0ffp085Of/MT09PSY3Nxcs23btsS6d955x0gyzc3Nxhhjdu7caVwulwmFQokxW7ZsMT6fz0Sj0RH/zHA4bCSZcDg81ukDADJgpL/HL/mcVCwW09atW9Xf369gMKjW1lYNDQ2puro6MWbevHmqqKhQc3OzJKm5uVkLFixQIBBIjKmpqVEkEkm8GruQaDSqSCSS9AAATHyjjtTBgwdVUFAgr9er22+/Xdu3b1dVVZVCoZA8Ho8KCwuTxgcCAYVCIUlSKBRKCtT59efXXUx9fb38fn/iUV5ePtppAwCy0KgjdeWVV6qtrU0tLS264447tGbNGr399tvjMbeEjRs3KhwOJx6dnZ3j+vMAAHZwj/YJHo9HV1xxhSRp8eLF2rdvn374wx/qpptu0uDgoHp6epJeTXV1dam0tFSSVFpaqr179yZt7/zVf+fHXIjX65XX6x3tVAEAWW7Mn5OKx+OKRqNavHixcnNz1djYmFjX3t6ujo4OBYNBSVIwGNTBgwfV3d2dGNPQ0CCfz6eqqqqxTgUAMMGM6pXUxo0b9dWvflUVFRXq7e3V008/rZdfflkvvfSS/H6/1q5dqw0bNqioqEg+n0933nmngsGgli9fLklauXKlqqqqdMstt+iRRx5RKBTSfffdp7q6Ol4pAQA+YVSR6u7u1ne/+12dOHFCfr9fCxcu1EsvvaQvf/nLkqTHHntMLpdLq1atUjQaVU1NjX70ox8lnp+Tk6MdO3bojjvuUDAYVH5+vtasWaOHHnootXsFAJgQHGOMyfQkRisSicjv9yscDsvn
82V6OgCAURrp73Hu3QcAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWmOK1MMPPyzHcbR+/frEsoGBAdXV1am4uFgFBQVatWqVurq6kp7X0dGh2tpa5eXlqaSkRPfcc4+Gh4fHMhUAwAR0yZHat2+ffvzjH2vhwoVJy++++269+OKL2rZtm5qamnT8+HHdeOONifWxWEy1tbUaHBzUa6+9pqeeekpPPvmkNm3adOl7AQCYmMwl6O3tNXPnzjUNDQ3m+uuvN3fddZcxxpienh6Tm5trtm3blhj7zjvvGEmmubnZGGPMzp07jcvlMqFQKDFmy5YtxufzmWg0OqKfHw6HjSQTDocvZfoAgAwb6e/xS3olVVdXp9raWlVXVyctb21t1dDQUNLyefPmqaKiQs3NzZKk5uZmLViwQIFAIDGmpqZGkUhEhw8fvuDPi0ajikQiSQ8AwMTnHu0Ttm7dqv3792vfvn2fWBcKheTxeFRYWJi0PBAIKBQKJcZ8NFDn159fdyH19fV68MEHRztVAECWG9Urqc7OTt111136+c9/rilTpozXnD5h48aNCofDiUdnZ2fafjYAIHNGFanW1lZ1d3fr6quvltvtltvtVlNTkx5//HG53W4FAgENDg6qp6cn6XldXV0qLS2VJJWWln7iar/zfz4/5uO8Xq98Pl/SAwAw8Y0qUitWrNDBgwfV1taWeCxZskSrV69O/HNubq4aGxsTz2lvb1dHR4eCwaAkKRgM6uDBg+ru7k6MaWhokM/nU1VVVYp2CwAwEYzqnNS0adM0f/78pGX5+fkqLi5OLF+7dq02bNigoqIi+Xw+3XnnnQoGg1q+fLkkaeXKlaqqqtItt9yiRx55RKFQSPfdd5/q6urk9XpTtFsAgIlg1BdO/DGPPfaYXC6XVq1apWg0qpqaGv3oRz9KrM/JydGOHTt0xx13KBgMKj8/X2vWrNFDDz2U6qkAALKcY4wxmZ7EaEUiEfn9foXDYc5PAUAWGunvce7dBwCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwljvTE7gUxhhJUiQSyfBMAACX4vzv7/O/zy8mKyN1+vRpSVJ5eXmGZwIAGIve3l75/f6Lrs/KSBUVFUmSOjo6PnXnJrtIJKLy8nJ1dnbK5/NlejrW4jiNDMdpZDhOI2OMUW9vr8rKyj51XFZGyuX64FSa3+/nP4IR8Pl8HKcR4DiNDMdpZDhOf9xIXmRw4QQAwFpECgBg
rayMlNfr1QMPPCCv15vpqViN4zQyHKeR4TiNDMcptRzzx67/AwAgQ7LylRQAYHIgUgAAaxEpAIC1iBQAwFpZGanNmzdrzpw5mjJlipYtW6a9e/dmekpp9corr+gb3/iGysrK5DiOnn/++aT1xhht2rRJl19+uaZOnarq6modOXIkacyZM2e0evVq+Xw+FRYWau3aterr60vjXoyv+vp6LV26VNOmTVNJSYluuOEGtbe3J40ZGBhQXV2diouLVVBQoFWrVqmrqytpTEdHh2pra5WXl6eSkhLdc889Gh4eTueujKstW7Zo4cKFiQ+eBoNB7dq1K7GeY3RhDz/8sBzH0fr16xPLOFbjxGSZrVu3Go/HY37605+aw4cPm9tuu80UFhaarq6uTE8tbXbu3Gn+4R/+wTz33HNGktm+fXvS+ocfftj4/X7z/PPPmzfffNP8xV/8hamsrDTnzp1LjPnKV75iFi1aZF5//XXz61//2lxxxRXm5ptvTvOejJ+amhrzs5/9zBw6dMi0tbWZr33ta6aiosL09fUlxtx+++2mvLzcNDY2mjfeeMMsX77cfP7zn0+sHx4eNvPnzzfV1dXmwIEDZufOnWbGjBlm48aNmdilcfHCCy+Y//iP/zD/9V//Zdrb2833v/99k5ubaw4dOmSM4RhdyN69e82cOXPMwoULzV133ZVYzrEaH1kXqWuuucbU1dUl/hyLxUxZWZmpr6/P4Kwy5+ORisfjprS01Dz66KOJZT09Pcbr9ZpnnnnGGGPM22+/bSSZffv2Jcbs2rXLOI5jjh07lra5p1N3d7eRZJqamowxHxyT3Nxcs23btsSYd955x0gyzc3NxpgP/mfA5XKZUCiUGLNlyxbj8/lMNBpN7w6k0fTp081PfvITjtEF9Pb2mrlz55qGhgZz/fXXJyLFsRo/WfV23+DgoFpbW1VdXZ1Y5nK5VF1drebm5gzOzB5Hjx5VKBRKOkZ+v1/Lli1LHKPm5mYVFhZqyZIliTHV1dVyuVxqaWlJ+5zTIRwOS/qfmxO3trZqaGgo6TjNmzdPFRUVScdpwYIFCgQCiTE1NTWKRCI6fPhwGmefHrFYTFu3blV/f7+CwSDH6ALq6upUW1ubdEwk/nsaT1l1g9lTp04pFosl/UuWpEAgoHfffTdDs7JLKBSSpAseo/PrQqGQSkpKkta73W4VFRUlxkwk8Xhc69ev17XXXqv58+dL+uAYeDweFRYWJo39+HG60HE8v26iOHjwoILBoAYGBlRQUKDt27erqqpKbW1tHKOP2Lp1q/bv3699+/Z9Yh3/PY2frIoUcCnq6up06NAhvfrqq5meipWuvPJKtbW1KRwO6xe/+IXWrFmjpqamTE/LKp2dnbrrrrvU0NCgKVOmZHo6k0pWvd03Y8YM5eTkfOKKma6uLpWWlmZoVnY5fxw+7RiVlpaqu7s7af3w8LDOnDkz4Y7junXrtGPHDu3Zs0ezZs1KLC8tLdXg4KB6enqSxn/8OF3oOJ5fN1F4PB5dccUVWrx4serr67Vo0SL98Ic/5Bh9RGtrq7q7u3X11VfL7XbL7XarqalJjz/+uNxutwKBAMdqnGRVpDwejxYvXqzGxsbEsng8rsbGRgWDwQzOzB6VlZUqLS1NOkaRSEQtLS2JYxQMBtXT06PW1tbEmN27dysej2vZsmVpn/N4MMZo3bp12r59u3bv3q3Kysqk9YsXL1Zubm7ScWpvb1dHR0fScTp48GBS0BsaGuTz+VRVVZWeHcmAeDyuaDTKMfqIFStW6ODBg2pra0s8lixZotWrVyf+mWM1TjJ95cZobd261Xi9XvPkk0+at99+23zve98zhYWFSVfMTHS9vb3mwIED5sCBA0aS+ad/+idz4MAB8/vf/94Y88El6IWFhebf//3fzVtvvWW++c1vXvAS9M997nOmpaXFvPrqq2bu3LkT6hL0O+64w/j9fvPyyy+bEydOJB5nz55NjLn99ttNRUWF2b17t3njjTdM
MBg0wWAwsf78JcMrV640bW1t5pe//KW57LLLJtQlw/fee69pamoyR48eNW+99Za59957jeM45le/+pUxhmP0aT56dZ8xHKvxknWRMsaYf/7nfzYVFRXG4/GYa665xrz++uuZnlJa7dmzx0j6xGPNmjXGmA8uQ7///vtNIBAwXq/XrFixwrS3tydt4/Tp0+bmm282BQUFxufzmVtvvdX09vZmYG/Gx4WOjyTzs5/9LDHm3Llz5m//9m/N9OnTTV5envnWt75lTpw4kbSd9957z3z1q181U6dONTNmzDB/93d/Z4aGhtK8N+Pnb/7mb8zs2bONx+Mxl112mVmxYkUiUMZwjD7NxyPFsRoffFUHAMBaWXVOCgAwuRApAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgrf8PGTgF/XqcDJ8AAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "-783.7704941864631"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Call the imitation-trained student's test() with play=True (presumably renders the\n",
    "# episode -- confirm against PPO.test); the returned value is the cell output.\n",
    "student.test(play=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Gym",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
